
import os
import json
import sys
import re
import random
import numpy as np
from openai import AzureOpenAI
from tqdm import tqdm
import argparse
# Import persona prompt dictionaries
from personas import (
    persona_prompts_10,
    persona_prompts_20,
    persona_prompts_30,
    persona_prompts_40,
    persona_prompts_50,
    persona_prompts_60,
    persona_prompts_70,
    persona_prompts_80,
    persona_prompts_90,
    persona_prompts_100,
)

# Persona prompt sets paired with their size label ("10" ... "100"). The pairs
# are kept in a fixed order so that evaluation always walks the sets
# deterministically.
PERSONA_SETS: list[tuple[str, dict]] = list(
    zip(
        ("10", "20", "30", "40", "50", "60", "70", "80", "90", "100"),
        (
            persona_prompts_10,
            persona_prompts_20,
            persona_prompts_30,
            persona_prompts_40,
            persona_prompts_50,
            persona_prompts_60,
            persona_prompts_70,
            persona_prompts_80,
            persona_prompts_90,
            persona_prompts_100,
        ),
    )
)

# (PERSONA_SETS is filtered by --persona_label only after the arguments are parsed, further below.)

parser = argparse.ArgumentParser(
    description='GMO ad CTR/CPA-percentile evaluation using persona-conditioned Azure OpenAI prompts.'
)
# Utility and dataset-slicing options
parser.add_argument('--count_personas', action='store_true', help='Output the number of personas')
parser.add_argument('--start', type=int, default=0, help='Start index (0-based) for dataset slicing')
parser.add_argument('--end', type=int, default=None, help='End index (0-based, inclusive) for dataset slicing')
parser.add_argument('--output_dir', type=str, default='results', help='Directory to save per-job JSON outputs')
# Repetition counts, e.g. "1,5,20" evaluates every datapoint with 1, 5 and 20
# calls per persona (one output file per count)
parser.add_argument('--runs_list', type=str, default='1',
                    help='Comma-separated list indicating how many times to repeat the prediction for each datapoint (default "1" for a single run per persona)')
# GMO evaluation inputs
parser.add_argument('--gmo', action='store_true',
                    help='Run the GMO CTR/CPA evaluation mode (ads); the legacy WebAES mode has been removed')
parser.add_argument('--dataset_dir', type=str, default='/path/to/ctr_dataset/',
                    help='Directory containing *.jsonl files for GMO evaluation')
# Sample-count limit and reuse of a prior run's datapoints
parser.add_argument('--limit', type=int, default=None, help='Total number of samples to evaluate across all dataset files (evenly sampled)')
parser.add_argument('--use_np_subset', action='store_true', help='Use the same datapoints as the prior GPT-CTR-NP run20 merged results')
# Seed for deterministic sampling across personas and run counts
parser.add_argument('--seed', type=int, default=42, help='Random seed to ensure consistent sampling across runs')
# Restrict evaluation to specific persona set label(s)
parser.add_argument('--persona_label', type=str, default=None,
                    help='Comma-separated list of persona set labels to evaluate (e.g., "20,40"). If omitted, all sets are evaluated.')
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Deterministic seeding
# ---------------------------------------------------------------------------
# Seed both generators before any sampling so that record selection is
# reproducible for a given --seed.
random.seed(args.seed)
np.random.seed(args.seed)

# Parse --runs_list into an ordered list of distinct positive ints.
try:
    _raw_runs = [int(tok) for tok in args.runs_list.split(',') if tok.strip()]
except ValueError as err:
    raise ValueError(
        f"--runs_list must be a comma-separated list of integers, got {args.runs_list!r}"
    ) from err
_non_positive_runs = [r for r in _raw_runs if r <= 0]
if _non_positive_runs:
    print(f"[WARNING] Ignoring non-positive --runs_list value(s): {_non_positive_runs}", file=sys.stderr)
# A repeated count would redo every API call and overwrite the same output
# file, so duplicates are dropped (dict.fromkeys keeps first-seen order).
runs_list = list(dict.fromkeys(r for r in _raw_runs if r > 0))
if not runs_list:
    raise ValueError("--runs_list must contain at least one positive integer")

# Resolve the output directory once and store the absolute path back on args
# so every later path is built from it.
output_dir = os.path.abspath(args.output_dir)
os.makedirs(output_dir, exist_ok=True)
args.output_dir = output_dir

# ---------------------------------------------------------------------------
# Optional filtering of persona sets based on --persona_label (after args)
# ---------------------------------------------------------------------------

if args.persona_label:
    requested_labels: set[str] = {lbl.strip() for lbl in args.persona_label.split(',') if lbl.strip()}
    available_labels: list[str] = [lbl for lbl, _ in PERSONA_SETS]
    # Surface typos such as "20,45" instead of silently evaluating only "20".
    unknown_labels = sorted(requested_labels.difference(available_labels))
    if unknown_labels:
        print(
            f"[WARNING] Unknown persona set label(s) ignored: {', '.join(unknown_labels)} "
            f"(available: {', '.join(available_labels)})",
            file=sys.stderr,
        )
    PERSONA_SETS = [(lbl, d) for lbl, d in PERSONA_SETS if lbl in requested_labels]
    if not PERSONA_SETS:
        print(f"[ERROR] No persona sets matched --persona_label={args.persona_label}", file=sys.stderr)
        sys.exit(1)

# Merged results JSON of the prior GPT-CTR-NP run20; its prompts define the
# record subset used when --use_np_subset is passed.
NP_SUBSET_PATH: str = "/path/to/final_gmo_results_runs20_merged.json"

def _load_np_subset_prompts(path: str) -> set[str]:
    """Return the set of non-empty prompt strings in a prior NP results file.

    The file must hold a JSON list of result records. Every dict record whose
    ``prompt`` is a non-empty string contributes that prompt; all other
    entries are ignored.

    Prints an error to stderr and exits with status 1 when the file is
    missing, cannot be read or decoded, or does not contain a JSON list.
    """
    if not os.path.isfile(path):
        print(f"[ERROR] NP subset file not found: {path}", file=sys.stderr)
        sys.exit(1)
    try:
        with open(path, 'r', encoding='utf-8') as f_in:
            data = json.load(f_in)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError and UnicodeDecodeError
        print(f"[ERROR] Failed to load NP subset file {path}: {e}", file=sys.stderr)
        sys.exit(1)
    if not isinstance(data, list):
        # Iterating a dict would yield only its keys and silently give an empty subset.
        print(f"[ERROR] NP subset file {path} must contain a JSON list, got {type(data).__name__}", file=sys.stderr)
        sys.exit(1)
    # Skip missing/empty prompts. An '' entry would make every record that has
    # no 'prompt' key pass the subset filter, because the filter looks prompts
    # up with a '' default.
    return {
        rec['prompt']
        for rec in data
        if isinstance(rec, dict) and isinstance(rec.get('prompt'), str) and rec['prompt']
    }

# ---------------------------------------------------------------------------
# NOTE: Legacy WebAES (website likeability) code and prompts have been removed.
# This script now focuses solely on GMO CPA-percentile evaluation.
# ---------------------------------------------------------------------------

# (Image handling utilities removed – not required for CPA evaluation)

# Azure OpenAI configuration. The API key and endpoint are taken from the
# standard AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT environment variables
# when set, so real credentials never need to be committed; the literals below
# are only placeholder fallbacks.
api_version = "2024-02-15-preview"
config_dict = {
    'api_key': os.environ.get("AZURE_OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
    'api_version': api_version,
    'azure_endpoint': os.environ.get("AZURE_OPENAI_ENDPOINT", "https://your-azure-openai-endpoint/"),
}
# ----------------------------- Helper Functions -----------------------------

# Shared Azure OpenAI client, created on first use by _get_azure_client().
_azure_client = None


def _get_azure_client() -> "AzureOpenAI":
    """Return the process-wide AzureOpenAI client, constructing it on first call.

    The evaluation issues one request per (record, persona, repetition);
    reusing one client avoids rebuilding it (and its connections) every time.
    """
    global _azure_client
    if _azure_client is None:
        _azure_client = AzureOpenAI(
            api_key=config_dict['api_key'],
            api_version=config_dict['api_version'],
            azure_endpoint=config_dict['azure_endpoint'],
        )
    return _azure_client


def _verbalize_persona(prompt: str, persona_system_prompt: str) -> str:
    """Ask the 'gpt-4o' deployment to answer ``prompt`` in the given persona.

    Sends one chat completion (system = ``persona_system_prompt``, user =
    ``prompt``; max_tokens=350, temperature=0.85) and returns the stripped
    reply text. A reply without content is returned as ``""``. API errors are
    not caught and propagate to the caller.
    """
    messages = [
        {"role": "system", "content": persona_system_prompt},
        {"role": "user", "content": prompt},
    ]
    resp = _get_azure_client().chat.completions.create(
        model='gpt-4o',
        messages=messages,
        max_tokens=350,
        temperature=0.85,
        n=1,
    )
    # message.content is Optional in the SDK (e.g. a filtered completion);
    # map None to "" instead of crashing on .strip().
    content = resp.choices[0].message.content
    return (content or "").strip()

def _sample_gmo_records(dataset_dir: str, total_limit: int | None):
    """Randomly sample records from each *.jsonl file in `dataset_dir`.

    If `total_limit` is provided, samples are taken as `total_limit // n_files` per file.
    Otherwise, all records from each file are returned.
    """
    file_paths = [os.path.join(dataset_dir, fp) for fp in os.listdir(dataset_dir) if fp.endswith('.jsonl')]
    if not file_paths:
        print(f"[ERROR] No .jsonl files found in {dataset_dir}", file=sys.stderr)
        sys.exit(1)

    random.shuffle(file_paths)  # shuffle to avoid ordering bias
    records = []
    per_file = None
    if total_limit is not None and total_limit > 0:
        per_file = max(1, total_limit // len(file_paths))

    for fp in file_paths:
        with open(fp, 'r', encoding='utf-8') as f_in:
            lines = f_in.readlines()
        if per_file is not None:
            chosen = random.sample(lines, min(per_file, len(lines)))
        else:
            chosen = lines
        for ln in chosen:
            try:
                rec = json.loads(ln)
                rec['_source_file'] = os.path.basename(fp)
                records.append(rec)
            except json.JSONDecodeError:
                continue  # skip malformed lines

    # If we overshot the limit due to rounding, trim back down
    if total_limit is not None and len(records) > total_limit:
        records = random.sample(records, total_limit)
    return records

# Finds the first "answer" (case-insensitive) followed, within 10 non-digit
# characters, by a number of up to three integer digits with optional decimals.
# Compiled once because it is applied to every model response.
_SCORE_PATTERN = re.compile(r"(?i)answer[^0-9]{0,10}(\d{1,3}(?:\.\d+)?)")


def _extract_score(resp_text: str) -> float | None:
    """Return the number following "answer" in ``resp_text``, clamped to [0, 100].

    Returns None when the response contains no such number.
    """
    match = _SCORE_PATTERN.search(resp_text)
    if match is None:
        return None
    return max(0.0, min(100.0, float(match.group(1))))


def _write_json_atomic(path: str, payload) -> None:
    """Write ``payload`` as indented JSON to ``path`` without risking truncation.

    The data goes to ``<path>.tmp`` first and is then moved over ``path`` with
    ``os.replace``. An interrupted write therefore leaves the previous complete
    file in place; at worst a stale ``.tmp`` file remains.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f_out:
        json.dump(payload, f_out, indent=2)
    os.replace(tmp_path, path)


def _evaluate_record(rec: dict, persona_prompts: dict, n_runs: int) -> dict:
    """Query every persona ``n_runs`` times on one record and aggregate the scores.

    Returns a dict containing:
    - the record's prompt;
    - its ``response`` as the ground truth;
    - per-persona predictions, mean and raw responses;
    - the mean over the personas that produced at least one score;
    - the record's source file.
    """
    ad_prompt = rec.get("prompt", "")
    persona_data: dict[str, dict] = {}
    scored_means: list[float] = []

    for persona_name, persona_sys_prompt in persona_prompts.items():
        predictions: list[float] = []
        responses: list[str] = []
        for _ in range(n_runs):
            resp_text = _verbalize_persona(ad_prompt, persona_sys_prompt)
            score = _extract_score(resp_text)
            if score is not None:
                predictions.append(score)
            responses.append(resp_text)

        mean_score_persona = float(np.mean(predictions)) if predictions else None
        if mean_score_persona is not None:
            scored_means.append(mean_score_persona)
        persona_data[persona_name] = {
            "predictions": predictions,
            "mean_prediction": mean_score_persona,
            "responses": responses,
        }

    # Only personas that produced a score contribute. If none did, store None
    # rather than NaN, because json.dump would emit the non-standard NaN token.
    overall_mean = float(np.mean(scored_means)) if scored_means else None

    return {
        "prompt": ad_prompt,
        "ground_truth": rec.get("response"),
        "personas": persona_data,
        "overall_mean_prediction": overall_mean,
        "source_file": rec.get("_source_file"),
    }


def run_gmo_evaluation(args):
    """Run the GMO evaluation for every persona set and every count in ``runs_list``.

    Records are sampled from ``args.dataset_dir``, optionally sliced with
    ``--start``/``--end``, and optionally restricted to the prompts of the
    prior NP run. One JSON results file is written per (persona set, run
    count) into ``args.output_dir``. The file is checkpointed atomically after
    every datapoint.
    """

    # ---------------------------- Record Sampling ----------------------------
    records = _sample_gmo_records(args.dataset_dir, args.limit)

    # Optional slicing with --start/--end (0-based indices, end inclusive).
    # It is applied before the NP-subset filter, so the indices refer to the
    # sampled order.
    slice_start = max(0, args.start)
    slice_end = args.end if args.end is not None else len(records) - 1
    slice_end = min(slice_end, len(records) - 1)
    records = records[slice_start : slice_end + 1]

    print(
        f"[INFO] Running GMO evaluation on {len(records)} sampled records (slice {slice_start}-{slice_end})."
    )

    os.makedirs(args.output_dir, exist_ok=True)

    # Optionally restrict to NP subset prompts
    if args.use_np_subset:
        subset_prompts = _load_np_subset_prompts(NP_SUBSET_PATH)
        before_filter = len(records)
        records = [rec for rec in records if rec.get('prompt', '') in subset_prompts]
        print(f"[INFO] Using NP subset: {len(records)} of {before_filter} records retained.")

    # ----------------------- Iterate over (possibly filtered) persona prompt sets ------------------
    for persona_label, persona_prompts in PERSONA_SETS:
        print("\n" + "#" * 80)
        print(f"Evaluating persona set: persona_prompts_{persona_label}")
        print("#" * 80)

        # ----------------------- Iterate over repetition counts ------------------
        for n_runs in runs_list:
            print("\n" + "=" * 80)
            print(f"Running evaluation with {n_runs} repetitions per datapoint…")
            print("=" * 80)

            run_results = []
            out_name = (
                f"gmo_persona_results_set{persona_label}_runs{n_runs}_samples{args.limit or 'all'}_{slice_start}_{slice_end}.json"
            )
            out_path = os.path.join(args.output_dir, out_name)

            for rec in tqdm(records, desc=f"GMO Samples x{n_runs}"):
                run_results.append(_evaluate_record(rec, persona_prompts, n_runs))
                # Checkpoint after every datapoint; a crash then loses at most
                # the datapoint in progress.
                try:
                    _write_json_atomic(out_path, run_results)
                except (OSError, TypeError, ValueError) as e:
                    print(f"[WARNING] Incremental save failed: {e}", file=sys.stderr)

            # Final write: covers an empty record list and retries a failed checkpoint.
            try:
                _write_json_atomic(out_path, run_results)
            except (OSError, TypeError, ValueError) as e:
                print(f"[ERROR] Final save failed for {out_path}: {e}", file=sys.stderr)
            else:
                print(f"[INFO] GMO evaluation with {n_runs} runs complete. Results saved to {out_path}")

# ---------------------------------------------------------------------------
# Entry point. --count_personas reports the size of each selected persona set;
# --gmo runs the evaluation, which is the only mode this script still supports.
# ---------------------------------------------------------------------------
if args.count_personas:
    _total_personas = 0
    for _label, _prompts in PERSONA_SETS:
        print(f"persona_prompts_{_label}: {len(_prompts)} personas")
        _total_personas += len(_prompts)
    print(f"Total personas across {len(PERSONA_SETS)} set(s): {_total_personas}")

if args.gmo:
    run_gmo_evaluation(args)
    sys.exit(0)

# The legacy WebAES mode was removed. Without --gmo (or --count_personas) there
# is nothing to do, so fail loudly instead of exiting silently with status 0.
if not args.count_personas:
    parser.error("nothing to do: pass --gmo to run the GMO evaluation (or --count_personas)")

# No additional code below this point.